# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json
import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, Iterable, List, Optional, Union

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.runners import SRTRunner
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    calculate_rouge_l,
    is_in_ci,
    popen_launch_server,
)

# Shared prompts used to build forward-pass batches for every test case below.
PROMPTS = [
    "SGL is a",
    "AI is a field of computer science focused on",
    "Computer science is the study of",
    "Write a short story.",
    "What are the main components of a computer?",
]

# Static memory fraction handed to the launched server/engine (usage is below
# this view); < 1.0 presumably leaves GPU headroom for dynamically loaded
# LoRA weights — TODO confirm at the launch site.
MEM_FRACTION_STATIC = 0.8


class OperationType(Enum):
    """The kinds of steps a test case can issue against the LoRA runtime."""

    LOAD = "load"  # Load a LoRA adapter.
    UNLOAD = "unload"  # Unload a previously loaded LoRA adapter.
    FORWARD = "forward"  # Run a forward pass with (prompt, adapter) pairs.


@dataclass
class Operation:
    """A single step in a test case's operation sequence."""

    # Operation type, can be LOAD, UNLOAD, FORWARD
    type: OperationType
    # Data associated with the operation. Exact type varies depending on the operation
    # (adapter name string, dict with lora_name/lora_path/pinned, batch data, or None).
    data: Optional[Any]
    # If the operation is expected to fail, this is the error message to expect
    # (a substring of the actual error — TODO confirm matching semantics at the check site)
    expected_error: Optional[str] = None
    # Because the logic for implicitly evicting LoRA adapters can be complicated, we explicitly
    # pass in LoRA adapters that should be implicitly evicted here
    expected_implicit_evictions: Optional[set[str]] = None


@dataclass
class TestCase:
    """Declarative description of one dynamic-LoRA-update test scenario."""

    # Human-readable summary of what this scenario exercises.
    description: str
    # Base model path/name.
    base: str
    # Max number of LoRA adapters usable within a single batch.
    max_loras_per_batch: int
    # Every adapter the scenario references ("name" or "name=path" form).
    all_adapters: List[str]
    # Ordered LOAD / UNLOAD / FORWARD steps to execute.
    op_sequence: List[Operation]
    # Adapters supplied at startup; entries may be plain strings,
    # "name=path" strings, or dicts with lora_name/lora_path/pinned keys.
    initial_adapters: Optional[List[Union[str, dict]]] = None
    enable_lora: Optional[bool] = None
    max_lora_rank: Optional[int] = None
    lora_target_modules: Optional[List[str]] = None
    max_new_tokens: int = 32
    max_loaded_loras: Optional[int] = None


def create_batch_data(
    adapters: Union[str, None, List[Optional[str]]],
) -> List[tuple[str, Optional[str]]]:
    """Pair every prompt in PROMPTS with every given adapter.

    Args:
        adapters: A single adapter name, ``None`` (meaning "base model, no
            adapter"), or a list of such values. The previous annotation
            (``Union[str, list]``) was wrong: callers in this file pass
            ``None`` both directly and inside lists.

    Returns:
        One ``(prompt, adapter)`` tuple per prompt/adapter combination, in
        prompt-major order.
    """
    # Wrap a lone adapter (including None) so both call forms share one path.
    if not isinstance(adapters, list):
        adapters = [adapters]
    return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]


# Core load/unload/forward scenarios: duplicate-load rejection, unload/reload
# cycles, and forwards against unloaded adapters — with and without adapters
# supplied at startup.
BASIC_TESTS = [
    TestCase(
        description="dynamic lora update with initial lora_paths",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        initial_adapters=[
            # Testing 3 supported lora-path formats.
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            {
                "lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
                "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
                "pinned": False,
            },
        ],
        op_sequence=[
            # Re-loading an adapter that came from initial_adapters must fail.
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                    ]
                ),
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
            ),
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                    ]
                ),
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                    ]
                ),
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
            ),
            # None adapter = base model only.
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    None,
                ),
            ),
        ],
    ),
    TestCase(
        description="dynamic lora update without initial lora_paths",
        base="meta-llama/Llama-3.1-8B-Instruct",
        enable_lora=True,
        max_lora_rank=256,
        lora_target_modules=["all"],
        max_loras_per_batch=4,
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        op_sequence=[
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        None,
                    ]
                ),
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        None,
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        None,
                    ]
                ),
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(None),
            ),
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_error="already loaded",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        None,
                    ]
                ),
            ),
        ],
    ),
]
# Scenarios covering explicit vs. inferred lora-target-modules: adapters whose
# target modules fit within the configured/inferred set load fine; adapters
# needing extra modules are rejected as "incompatible".
TARGET_MODULE_TESTS = [
    TestCase(
        description="Test explicitly specified lora-target-modules.",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        lora_target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        max_lora_rank=64,
        all_adapters=[
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
        ],
        initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
        op_sequence=[
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
                expected_error="never been loaded",
            ),
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        None,
                    ]
                ),
            ),
        ],
    ),
    TestCase(
        description="Test inferred lora-target-modules - start with larger adapter",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        max_lora_rank=64,
        all_adapters=[
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
        ],
        initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
        op_sequence=[
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
                ),
                expected_error="never been loaded",
            ),
            # Smaller adapter fits within the module set inferred from the
            # larger initial adapter, so this load succeeds.
            Operation(
                type=OperationType.LOAD,
                data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        None,
                    ]
                ),
            ),
        ],
    ),
    TestCase(
        description="Test inferred lora-target-modules - start with smaller adapter",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        max_lora_rank=64,
        all_adapters=[
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
        ],
        initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
        op_sequence=[
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
                expected_error="never been loaded",
            ),
            # Larger adapter needs modules (up, down) outside the set inferred
            # from the smaller initial adapter, so this load is rejected.
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                expected_error="incompatible",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                        None,
                    ]
                ),
            ),
        ],
    ),
]
# Scenarios covering explicit vs. inferred max-lora-rank: adapters whose rank
# exceeds the configured/inferred maximum are rejected as "incompatible".
MAX_LORA_RANK_TESTS = [
    TestCase(
        description="Test explicitly specified max-lora-rank.",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        max_lora_rank=32,
        all_adapters=[
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # r = 4
            "pbevan11/llama-3.1-8b-ocr-correction",  # r = 32
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # r = 256
        ],
        initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
        op_sequence=[
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
                expected_error="never been loaded",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
                expected_error="never been loaded",
            ),
            # r = 32 adapter fits the explicit max_lora_rank=32.
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        None,
                    ]
                ),
            ),
            # r = 256 adapter exceeds max_lora_rank=32 and is rejected.
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
                expected_error="incompatible",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                ),
                expected_error="never been loaded",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        None,
                    ]
                ),
            ),
        ],
    ),
    TestCase(
        description="test implicitly inferred max-lora-rank",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=3,
        all_adapters=[
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # r = 4
            "pbevan11/llama-3.1-8b-ocr-correction",  # r = 32
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # r = 256
        ],
        initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
        op_sequence=[
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
            ),
            # Max rank was inferred from the r = 32 initial adapter, so the
            # r = 256 adapter is rejected while r = 4 loads fine below.
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
                expected_error="incompatible",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
                expected_error="never been loaded",
            ),
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                        None,
                    ]
                ),
            ),
        ],
    ),
]
# Scenarios covering the max_loaded_loras cap: exceeding it triggers implicit
# LRU eviction (verified via expected_implicit_evictions), and pinned adapters
# can make further pinned loads fail with a "starvation" error.
MAX_LOADED_LORAS_TESTS = [
    TestCase(
        description="Test max_loaded_loras limit as well as implicit eviction and reloading",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=2,
        max_loaded_loras=2,
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
        op_sequence=[
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_implicit_evictions={
                    "philschmid/code-llama-3-1-8b-text-to-sql-lora"
                },
            ),
            # Implicitly load "philschmid/code-llama-3-1-8b-text-to-sql-lora"
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                    ]
                ),
                expected_implicit_evictions={"pbevan11/llama-3.1-8b-ocr-correction"},
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                    ]
                ),
            ),
            Operation(
                type=OperationType.LOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            # Implicitly load "pbevan11/llama-3.1-8b-ocr-correction" and make sure that "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
            # isn't implicitly unloaded even though it is LRU because it is needed for this forward pass
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                    ]
                ),
                expected_implicit_evictions={
                    "philschmid/code-llama-3-1-8b-text-to-sql-lora"
                },
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            # NOTE(review): this adapter is not listed in this case's
            # all_adapters above — confirm the harness doesn't rely on that
            # list for validation/cleanup.
            Operation(
                type=OperationType.LOAD,
                data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                    ]
                ),
                expected_implicit_evictions={
                    "pbevan11/llama-3.1-8b-ocr-correction",
                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
                },
            ),
        ],
    ),
    TestCase(
        description="Test implicit eviction and reloading with pinned LoRA adapters",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=2,
        max_loaded_loras=2,
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        initial_adapters=[
            {
                "lora_name": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                "pinned": True,
            }
        ],
        op_sequence=[
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            # The pinned initial adapter is immune to eviction, so the
            # unpinned one is evicted instead.
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_implicit_evictions={
                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
                },
            ),
            # Implicitly load "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                    ]
                ),
                expected_implicit_evictions={"pbevan11/llama-3.1-8b-ocr-correction"},
            ),
            # Pinning a second adapter when max_loaded_loras=2 and one slot is
            # already pinned would starve unpinned adapters, so it's rejected.
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
                    "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
                    "pinned": True,
                },
                expected_error="starvation",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
                    "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
                    "pinned": True,
                },
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                        "pbevan11/llama-3.1-8b-ocr-correction",
                    ]
                ),
            ),
        ],
    ),
]
# Scenarios mixing pinned/unpinned adapters (registered under short alias
# names via "name=path") with max_loras_per_batch=2, exercising pin-induced
# "starvation" rejections and forwards that force per-batch eviction.
EVICTION_TESTS = [
    TestCase(
        description="dynamic lora update with evictions",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=2,
        all_adapters=[
            "lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "lora3=pbevan11/llama-3.1-8b-ocr-correction",
        ],
        enable_lora=True,
        max_lora_rank=256,
        lora_target_modules=["all"],
        op_sequence=[
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora1",
                    "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                    "pinned": True,
                },
            ),
            # With max_loras_per_batch=2, pinning a second adapter would leave
            # no slot for unpinned ones — rejected as starvation.
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora2",
                    "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                    "pinned": True,
                },
                expected_error="starvation",
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora2",
                    "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
                    "pinned": False,
                },
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora3",
                    "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
                    "pinned": False,
                },
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="lora1",
            ),
            Operation(
                type=OperationType.UNLOAD,
                data="lora3",
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora3",
                    "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
                    "pinned": True,
                },
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora1",
                    "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                    "pinned": True,
                },
                expected_error="starvation",
            ),
            Operation(
                type=OperationType.LOAD,
                data={
                    "lora_name": "lora1",
                    "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
                    "pinned": False,
                },
            ),
            # pinned: lora3
            # unpinned: lora1, lora2
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "lora1",
                        "lora2",
                    ]
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "lora1",
                        "lora3",
                    ]
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "lora1",
                        "lora2",
                    ]
                ),
            ),
            Operation(
                type=OperationType.FORWARD,
                data=create_batch_data(
                    [
                        "lora1",
                        "lora2",
                        None,
                    ]
                ),
            ),
        ],
    ),
]

# Aggregate suite: every scenario group, in declaration order.
ALL_TESTS = [
    *BASIC_TESTS,
    *TARGET_MODULE_TESTS,
    *MAX_LORA_RANK_TESTS,
    *MAX_LOADED_LORAS_TESTS,
    *EVICTION_TESTS,
]


class LoRAUpdateTestSessionMode(Enum):
    """How a test session drives SGLang: an in-process engine or a launched
    server (presumably reached over HTTP, given the requests import — TODO
    confirm in the session subclasses)."""

    ENGINE = "engine"
    SERVER = "server"


class LoRAUpdateTestSessionBase:
    """
    Base context manager for testing LoRA adapters.

    Maintains `expected_adapters`, the set of adapter names that should be
    loaded at any point in the session; subclasses assert the engine/server
    state against it after each load/unload/forward operation.
    """

    def __init__(
        self,
        *,
        testcase: Optional[TestCase],
        model_path: str,
        lora_paths: List[Union[str, dict]],
        max_loras_per_batch: int,
        max_loaded_loras: Optional[int] = None,
        max_lora_rank: Optional[int],
        enable_lora: Optional[bool] = None,
        lora_target_modules: Optional[List[str]] = None,
        lora_backend: str = "csgmv",
        disable_cuda_graph: bool = False,
        cuda_graph_max_bs: int = 4,
    ):
        """
        Args:
            testcase: test case whose assert* methods are used by subclasses.
            model_path: base model path / HF repo id.
            lora_paths: initial adapters; each entry is a dict containing
                "lora_name", a "name=path" string, or a bare name/path string.
            max_loras_per_batch: max adapters usable in a single batch.
            max_loaded_loras: optional cap on concurrently loaded adapters.
            max_lora_rank: optional max LoRA rank override.
            enable_lora: enable LoRA explicitly (needed when no initial adapters).
            lora_target_modules: optional target-module override.
            lora_backend: LoRA kernel backend name.
            disable_cuda_graph: disable CUDA graph capture.
            cuda_graph_max_bs: max batch size captured by CUDA graph.
        """
        self.testcase = testcase
        self.model_path = model_path
        self.lora_paths = lora_paths
        self.max_lora_rank = max_lora_rank
        self.lora_target_modules = lora_target_modules
        self.max_loras_per_batch = max_loras_per_batch
        self.max_loaded_loras = max_loaded_loras
        self.lora_backend = lora_backend
        self.disable_cuda_graph = disable_cuda_graph
        self.cuda_graph_max_bs = cuda_graph_max_bs
        self.enable_lora = enable_lora

        # Derive the initial expected adapter names from the three accepted
        # lora_paths entry formats (dict, "name=path" string, bare name).
        self.expected_adapters = set()
        if self.lora_paths:
            for adapter in self.lora_paths:
                if isinstance(adapter, dict):
                    lora_name = adapter["lora_name"]
                elif "=" in adapter:
                    lora_name = adapter.split("=")[0]
                else:
                    lora_name = adapter
                self.expected_adapters.add(lora_name)

        self.handle = None  # Will be set in __enter__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Don't suppress exceptions by default
        return False

    def load_lora_adapter(
        self,
        lora_name: str,
        lora_path: Optional[str] = None,
        expected_error: Optional[str] = None,
        pinned: bool = False,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Load a LoRA adapter by name and path.

        Consistency fix: both subclasses accept a `pinned` keyword, so the
        abstract interface now declares it too (default False, backward
        compatible).
        """
        raise NotImplementedError("Subclasses must implement load_lora_adapter")

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload a LoRA adapter by name.
        """
        raise NotImplementedError("Subclasses must implement unload_lora_adapter")

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
        expected_error: Optional[str] = None,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Perform a batch forward pass with the current set of loaded LoRA adapters.
        """
        raise NotImplementedError("Subclasses must implement forward")


class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
    """
    Context manager for testing LoRA adapters with in-process engine.

    All operations are delegated to an `SRTRunner` hosting the model in the
    current process.
    """

    def __enter__(self):
        # in-process runner
        self.handle = SRTRunner(
            model_path=self.model_path,
            model_type="generation",
            lora_paths=self.lora_paths,
            max_lora_rank=self.max_lora_rank,
            lora_target_modules=self.lora_target_modules,
            lora_backend=self.lora_backend,
            torch_dtype=torch.float16,
            mem_fraction_static=MEM_FRACTION_STATIC,
            max_loras_per_batch=self.max_loras_per_batch,
            max_loaded_loras=self.max_loaded_loras,
            disable_cuda_graph=self.disable_cuda_graph,
            cuda_graph_max_bs=self.cuda_graph_max_bs,
            enable_lora=self.enable_lora,
            disable_radix_cache=True,
        )
        self.handle.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.handle is not None:
            # delegate cleanup to SRTRunner
            return self.handle.__exit__(exc_type, exc_val, exc_tb)
        # don't suppress exceptions
        return False

    def load_lora_adapter(
        self,
        lora_name: str,
        lora_path: Optional[str] = None,
        expected_error: Optional[str] = None,
        pinned: bool = False,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Load a LoRA adapter by name and path.

        When `expected_error` is set, asserts the load fails with a message
        containing it; otherwise asserts success and that the engine's reported
        adapter set matches `self.expected_adapters`.
        """
        if lora_path is None:
            lora_path = lora_name

        response = self.handle.load_lora_adapter(
            lora_name=lora_name,
            lora_path=lora_path,
            pinned=pinned,
        )
        if expected_error:
            self.testcase.assertFalse(
                response.success, f"Expected failure for {lora_name}, but got success."
            )
            self.testcase.assertIn(
                expected_error,
                response.error_message,
                f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
            )
            print(f"Received error as expected: {response.error_message}")
        else:
            self.expected_adapters.add(lora_name)
            if expected_implicit_evictions is not None:
                self.expected_adapters -= expected_implicit_evictions

            self.testcase.assertTrue(
                response.success,
                f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
            )
            loaded_adapters = set(response.loaded_adapters)
            print(f"loaded_adapters: {loaded_adapters}")
            self.testcase.assertEqual(
                loaded_adapters,
                self.expected_adapters,
                f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
            )

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload a LoRA adapter by name.

        Raises KeyError if `lora_name` was never in the expected set.
        """
        self.expected_adapters.remove(lora_name)

        response = self.handle.unload_lora_adapter(
            lora_name=lora_name,
        )
        self.testcase.assertTrue(
            response.success,
            f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
        )
        loaded_adapters = set(response.loaded_adapters)

        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(
            loaded_adapters,
            self.expected_adapters,
            f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
        )

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
        expected_error: Optional[str] = None,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Perform a batch forward pass with the current set of loaded LoRA adapters.

        Returns the generated output strings on success, or the error message
        when `expected_error` matched a raised ValueError.
        """
        try:
            response = self.handle.batch_forward(
                prompts=prompts,
                lora_paths=lora_paths,
                max_new_tokens=max_new_tokens,
            )
        except ValueError as e:
            if expected_error:
                error_message = str(e)
                self.testcase.assertIn(
                    expected_error,
                    error_message,
                    f"Expected error message to contain '{expected_error}', but got '{error_message}'",
                )
                print(f"Received error as expected: {error_message}")
                return error_message

            # Unexpected failure: re-raise preserving the original traceback.
            raise

        # Bug fix: previously a forward that was expected to fail but succeeded
        # passed silently. Fail explicitly, mirroring the server-mode session
        # which asserts an HTTP 400 in this situation.
        if expected_error:
            self.testcase.fail(
                f"Expected forward pass to fail with error containing '{expected_error}', but it succeeded."
            )

        self.testcase.assertEqual(
            len(response.output_strs),
            len(prompts),
            f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
        )
        output = response.output_strs
        print(f"output_strs: {output}")

        # Adapters used in this batch are now resident (possibly evicting others).
        self.expected_adapters.update(
            [lora_path for lora_path in lora_paths if lora_path is not None]
        )

        if expected_implicit_evictions is not None:
            self.expected_adapters -= expected_implicit_evictions

        return output


class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
    """
    Context manager for testing LoRA adapters with standalone server.

    Each operation is issued over HTTP against the launched server process.
    """

    def __enter__(self):
        # Flags shared by every test configuration.
        args = [
            "--cuda-graph-max-bs",
            str(self.cuda_graph_max_bs),
            "--max-loras-per-batch",
            str(self.max_loras_per_batch),
            "--lora-backend",
            self.lora_backend,
            "--random-seed",
            "42",
            "--max-running-request",
            "1",
            "--mem-fraction-static",
            str(MEM_FRACTION_STATIC),
            "--disable-radix-cache",
        ]
        # Optional flags, driven by the session configuration.
        if self.enable_lora:
            args.append("--enable-lora")
        if self.lora_paths:
            args.append("--lora-paths")
            # Dict-style entries are serialized to JSON for the CLI.
            args.extend(
                json.dumps(entry) if isinstance(entry, dict) else entry
                for entry in self.lora_paths
            )
        if self.disable_cuda_graph:
            args.append("--disable-cuda-graph")
        if self.max_lora_rank is not None:
            args += ["--max-lora-rank", str(self.max_lora_rank)]
        if self.lora_target_modules is not None:
            args += ["--lora-target-modules", *self.lora_target_modules]
        if self.max_loaded_loras is not None:
            args += ["--max-loaded-loras", str(self.max_loaded_loras)]

        # launch external server
        self.handle = popen_launch_server(
            self.model_path,
            DEFAULT_URL_FOR_TEST,
            DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=args,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Tear down the entire server process tree; never suppress exceptions.
        if self.handle is not None:
            kill_process_tree(self.handle.pid)
        return False

    def load_lora_adapter(
        self,
        lora_name: str,
        lora_path: Optional[str] = None,
        expected_error: Optional[str] = None,
        pinned: bool = False,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Load a LoRA adapter by name and path.
        """
        response = requests.post(
            DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
            json={
                "lora_name": lora_name,
                # Adapter name doubles as the path when none is given.
                "lora_path": lora_name if lora_path is None else lora_path,
                "pinned": pinned,
            },
        )

        if expected_error:
            # The server reports expected failures with a 400 status.
            self.testcase.assertEqual(
                response.status_code,
                400,
                f"Expected error for {lora_name}, but got success.",
            )
            self.testcase.assertIn(
                expected_error,
                response.text,
                f"Expected error message to contain '{expected_error}', but got '{response.text}'",
            )
            print(f"Received error as expected: {response.text}")
            return

        self.expected_adapters.add(lora_name)
        if expected_implicit_evictions is not None:
            self.expected_adapters -= expected_implicit_evictions

        self.testcase.assertTrue(
            response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
        )
        loaded_adapters = set(response.json()["loaded_adapters"])
        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(
            loaded_adapters,
            self.expected_adapters,
            f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
        )

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload a LoRA adapter by name.
        """
        # Drop from the expectation set first; the server response must match it.
        self.expected_adapters.remove(lora_name)

        response = requests.post(
            DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
            json={"lora_name": lora_name},
        )
        self.testcase.assertTrue(
            response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
        )

        loaded_adapters = set(response.json()["loaded_adapters"])
        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(
            loaded_adapters,
            self.expected_adapters,
            f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
        )

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
        expected_error: Optional[str] = None,
        expected_implicit_evictions: Optional[set[str]] = None,
    ):
        """
        Perform a batch forward pass with the current set of loaded LoRA adapters.
        """
        payload = {
            "text": prompts,
            "lora_path": lora_paths,
            # Greedy decoding so outputs are comparable across passes.
            "sampling_params": {
                "temperature": 0,
                "top_k": 1,
                "max_new_tokens": max_new_tokens,
            },
        }
        response = requests.post(DEFAULT_URL_FOR_TEST + "/generate", json=payload)

        if expected_error:
            self.testcase.assertEqual(
                response.status_code,
                400,
                f"Expected error for forward pass, but got success: {response.text}",
            )
            self.testcase.assertIn(
                expected_error,
                response.text,
                f"Expected error message to contain '{expected_error}', but got '{response.text}'",
            )
            print(f"Received error as expected: {response.text}")
            return response.text

        self.testcase.assertTrue(
            response.ok, f"Failed to generate text: {response.text}"
        )
        output = [item["text"] for item in response.json()]
        self.testcase.assertEqual(
            len(output),
            len(prompts),
            f"Expected {len(prompts)} outputs, but got {len(output)}",
        )
        print(f"output_strs: {output}")

        # Adapters used in this batch are now resident on the server.
        self.expected_adapters.update(
            path for path in lora_paths if path is not None
        )
        if expected_implicit_evictions is not None:
            self.expected_adapters -= expected_implicit_evictions

        return output


# Factory function to create the appropriate LoRA test session based on mode
def LoRAUpdateTestSession(
    testcase: Optional[TestCase],
    mode: LoRAUpdateTestSessionMode,
    **kwargs: Any,
):
    """Build the engine- or server-backed session for `mode`; kwargs are forwarded."""
    if mode is LoRAUpdateTestSessionMode.ENGINE:
        session_cls = LoRAUpdateEngineTestSession
    elif mode is LoRAUpdateTestSessionMode.SERVER:
        session_cls = LoRAUpdateServerTestSession
    else:
        raise ValueError(f"Unrecognized mode: {mode!r}")
    return session_cls(testcase=testcase, **kwargs)


class TestLoRADynamicUpdate(CustomTestCase):
    """
    This test case verifies that the SRT runner can dynamically load and unload LoRA adapters
    during a sequence of operations, and that the outputs of forward passes with dynamically loaded
    adapters match the outputs of forward passes with statically loaded adapters.
    """

    @staticmethod
    def _repeat_each(lst, n):
        """Repeat each element of `lst` `n` times: [a, b] with n=2 -> [a, a, b, b]."""
        # Bug fix: this helper took (lst, n) without `self` and without
        # @staticmethod, so an instance call would have bound the instance to
        # `lst`. Marked @staticmethod so it behaves correctly either way.
        return [x for x in lst for _ in range(n)]

    def _run_operation_sequence(
        self,
        mode: LoRAUpdateTestSessionMode,
        base: str,
        initial_adapters: List[Union[str, dict]],
        op_sequence: List[Operation],
        max_loras_per_batch: int,
        max_loaded_loras: Optional[int] = None,
        enable_lora: Optional[bool] = None,
        max_lora_rank: Optional[int] = None,
        lora_target_modules: Optional[List[str]] = None,
        max_new_tokens: int = 32,
    ) -> List[tuple]:
        """
        Runs a sequence of operations on the SRT runner, including loading and unloading LoRA adapters,
        and performing forward passes with the current set of loaded adapters.

        Returns the outputs of all FORWARD operations that were not expected to
        fail, in execution order.
        """

        forward_outputs = []
        with LoRAUpdateTestSession(
            testcase=self,
            mode=mode,
            model_path=base,
            lora_paths=initial_adapters,
            max_loras_per_batch=max_loras_per_batch,
            max_loaded_loras=max_loaded_loras,
            max_lora_rank=max_lora_rank,
            lora_target_modules=lora_target_modules,
            enable_lora=enable_lora,
        ) as session:
            for op in op_sequence:
                op_type = op.type
                data = op.data
                expected_error = op.expected_error
                expected_implicit_evictions = op.expected_implicit_evictions
                print("-" * 100)
                print(
                    f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
                )
                if op_type == OperationType.LOAD:
                    # LOAD data is either a bare adapter name/path string or a
                    # full {"lora_name", "lora_path", "pinned"} dict.
                    if isinstance(data, str):
                        adapter_info = {
                            "lora_name": data,
                            "lora_path": data,
                            "pinned": False,
                        }
                    else:
                        adapter_info = data

                    result = session.load_lora_adapter(
                        expected_error=expected_error,
                        expected_implicit_evictions=expected_implicit_evictions,
                        **adapter_info,
                    )
                elif op_type == OperationType.UNLOAD:
                    # UNLOAD data is the adapter name.
                    result = session.unload_lora_adapter(
                        lora_name=data,
                    )
                elif op_type == OperationType.FORWARD:
                    # FORWARD data is a sequence of (prompt, adapter_name) pairs.
                    prompts, adapters = zip(*data)
                    result = session.forward(
                        prompts=list(prompts),
                        lora_paths=list(adapters),
                        max_new_tokens=max_new_tokens,
                        expected_error=expected_error,
                        expected_implicit_evictions=expected_implicit_evictions,
                    )
                    # Only successful forwards are compared against the static pass.
                    if not expected_error:
                        forward_outputs.append(result)

            return forward_outputs

    def _run_dynamic_adapter_updates(
        self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
    ):
        """
        For each test case, run the op sequence twice — once with adapters
        loaded dynamically, once with every adapter preloaded statically — and
        require the two passes to produce near-identical outputs (ROUGE-L).
        """
        for case_idx, test_case in enumerate(test_cases, start=1):
            print("=" * 100)
            print(
                f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
            )
            print("=" * 100)

            print(
                f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
            )
            # Test dynamic loading of adapters
            dynamic_output = self._run_operation_sequence(
                mode=mode,
                initial_adapters=test_case.initial_adapters,
                enable_lora=test_case.enable_lora,
                base=test_case.base,
                max_loras_per_batch=test_case.max_loras_per_batch,
                max_loaded_loras=test_case.max_loaded_loras,
                op_sequence=test_case.op_sequence,
                max_new_tokens=test_case.max_new_tokens,
                max_lora_rank=test_case.max_lora_rank,
                lora_target_modules=test_case.lora_target_modules,
            )

            # Static pass replays only the forwards that were expected to succeed.
            forward_ops = [
                x
                for x in test_case.op_sequence
                if x.type == OperationType.FORWARD and x.expected_error is None
            ]

            if not forward_ops:
                print(
                    f"No forward operations found in test case {case_idx}. Skipping static pass."
                )
                continue

            print("=" * 100)
            print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
            # NOTE(review): the static pass omits max_loaded_loras, max_lora_rank and
            # lora_target_modules — presumably these can be inferred from the
            # statically preloaded adapters; confirm if rank/target-module cases regress.
            static_output = self._run_operation_sequence(
                mode=mode,
                initial_adapters=test_case.all_adapters,
                enable_lora=test_case.enable_lora,
                base=test_case.base,
                max_loras_per_batch=test_case.max_loras_per_batch,
                op_sequence=forward_ops,
                max_new_tokens=test_case.max_new_tokens,
            )

            # Minimum acceptable ROUGE-L similarity between dynamic and static outputs.
            ROUGE_L_TOL = 0.9

            print(f"Dynamic output: {dynamic_output}")
            print(f"Static output: {static_output}")
            print("=" * 100)
            self.assertEqual(
                len(dynamic_output),
                len(static_output),
                f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
            )
            # Compare batch-by-batch, prompt-by-prompt.
            for i, (dynamic, static) in enumerate(
                zip(dynamic_output, static_output), start=1
            ):
                self.assertEqual(
                    len(dynamic),
                    len(static),
                    f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
                )
                for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
                    d_out_str = d_out.strip()
                    s_out_str = s_out.strip()
                    rouge_score = calculate_rouge_l([d_out_str], [s_out_str])[0]

                    self.assertGreaterEqual(
                        rouge_score,
                        ROUGE_L_TOL,
                        f"ROUGE-L score {rouge_score} of outputs is below tolerance of {ROUGE_L_TOL} "
                        f"at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
                    )

    def test_dynamic_lora_update_engine(self):
        """
        Test dynamic LoRA updates in engine mode.
        """
        # CI runs the smaller suite to keep runtime bounded.
        test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
        self._run_dynamic_adapter_updates(
            mode=LoRAUpdateTestSessionMode.ENGINE,
            test_cases=test_cases,
        )

    def test_dynamic_lora_update_server(self):
        """
        Test dynamic LoRA updates in server mode.
        """
        # CI runs the smaller suite to keep runtime bounded.
        test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
        self._run_dynamic_adapter_updates(
            mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
        )

    def test_v1_models_endpoint_with_lora(self):
        """
        Test that /v1/models endpoint returns base model and loaded LoRA adapters.
        """
        adapters = [
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
        ]

        with LoRAUpdateTestSession(
            testcase=self,
            mode=LoRAUpdateTestSessionMode.SERVER,
            model_path="meta-llama/Llama-3.1-8B-Instruct",
            lora_paths=[],
            max_loras_per_batch=2,
            max_lora_rank=256,
            lora_target_modules=["all"],
            enable_lora=True,
        ) as session:
            # Test with no adapters loaded
            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
            self.assertTrue(response.ok, response.text)
            models_data = response.json()
            self.assertEqual(models_data["object"], "list")
            self.assertEqual(len(models_data["data"]), 1)  # Only base model
            base_model = models_data["data"][0]
            self.assertIn("meta-llama", base_model["id"].lower())
            # The base model has no parent; adapters are identified by a
            # non-null "parent" field below.
            self.assertIsNone(base_model.get("parent"))

            # Load first adapter
            session.load_lora_adapter(lora_name="adapter1", lora_path=adapters[0])

            # Test with one adapter loaded
            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
            self.assertTrue(response.ok, response.text)
            models_data = response.json()
            self.assertEqual(len(models_data["data"]), 2)  # Base model + 1 adapter

            # Verify adapter information
            adapter_models = [m for m in models_data["data"] if m.get("parent")]
            self.assertEqual(len(adapter_models), 1)
            self.assertEqual(adapter_models[0]["id"], "adapter1")
            self.assertEqual(adapter_models[0]["root"], adapters[0])
            self.assertIsNotNone(adapter_models[0]["parent"])

            # Load second adapter
            session.load_lora_adapter(lora_name="adapter2", lora_path=adapters[1])

            # Test with two adapters loaded
            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
            self.assertTrue(response.ok, response.text)
            models_data = response.json()
            self.assertEqual(len(models_data["data"]), 3)  # Base model + 2 adapters

            # Verify both adapters are listed
            adapter_models = [m for m in models_data["data"] if m.get("parent")]
            self.assertEqual(len(adapter_models), 2)
            adapter_names = {m["id"] for m in adapter_models}
            self.assertEqual(adapter_names, {"adapter1", "adapter2"})

            # Unload one adapter
            session.unload_lora_adapter(lora_name="adapter1")

            # Test after unloading
            response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
            self.assertTrue(response.ok, response.text)
            models_data = response.json()
            self.assertEqual(len(models_data["data"]), 2)  # Base model + 1 adapter
            adapter_models = [m for m in models_data["data"] if m.get("parent")]
            self.assertEqual(len(adapter_models), 1)
            self.assertEqual(adapter_models[0]["id"], "adapter2")


if __name__ == "__main__":
    # Force the "spawn" start method for worker processes — presumably so
    # subprocesses start with a clean interpreter state (standard practice for
    # CUDA-backed tests); TODO confirm against the runner's requirements.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # set_start_method raises RuntimeError if the start method was already
        # set (e.g. by an outer test harness); that is safe to ignore.
        pass

    unittest.main(warnings="ignore")
